import torch

# safe_softmax: numerically stable softmax (the max is subtracted before exponentiating)
def safe_softmax(a: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Return the softmax of ``a`` along ``dim``, computed in a stable way.

    The maximum along ``dim`` is subtracted from every element before
    exponentiating. This leaves the mathematical result unchanged but keeps
    the largest exponent at exactly zero, so ``exp`` cannot overflow for
    large finite inputs.

    Args:
        a: Input tensor of scores.
        dim: Dimension along which to normalize. Defaults to the last one.

    Returns:
        A tensor with the same shape as ``a`` whose entries along ``dim``
        are non-negative and sum to 1.

    Note:
        A slice along ``dim`` that is entirely ``-inf`` (or that contains
        ``+inf``) yields NaN, because the max-subtraction evaluates
        ``inf - inf``.
    """
    # keepdim=True keeps the reduced axis so the results broadcast against `a`.
    peak = torch.max(a, dim=dim, keepdim=True).values
    # After the shift every exponent is <= 0, so exp() stays in (0, 1].
    exps = torch.exp(a - peak)
    normalizer = torch.sum(exps, dim=dim, keepdim=True)
    return exps / normalizer

